{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune classifier with ModernBERT in 2025\n",
    "\n",
    "Large Language Models (LLMs) have become ubiquitous in 2024. However, smaller, specialized models - particularly for classification tasks - remain critical for building efficient and cost-effective AI systems. One key use case is routing user prompts to the most appropriate LLM or selecting optimal few-shot examples, where fast, accurate classification is essential.\n",
    "\n",
    "This blog post demonstrates how to fine-tune ModernBERT, a new state-of-the-art encoder model, for classifying user prompts to implement an intelligent LLM router. ModernBERT is a refreshed version of BERT models, with 8192 token context length, significantly better downstream performance, and much faster processing speeds.\n",
    "\n",
    "You will learn how to:\n",
    "1. Setup environment and install libraries\n",
    "2. Load and prepare the classification dataset  \n",
    "3. Fine-tune & evaluate ModernBERT with the Hugging Face `Trainer`\n",
    "4. Run inference & test model\n",
    "\n",
    "## Quick intro: ModernBERT\n",
    "\n",
    "ModernBERT is a modernization of BERT maintaining full backward compatibility while delivering dramatic improvements through architectural innovations like rotary positional embeddings (RoPE), alternating attention patterns, and hardware-optimized design. The model comes in two sizes:\n",
    "- ModernBERT Base (139M parameters)\n",
    "- ModernBERT Large (395M parameters)\n",
    "\n",
    "ModernBERT achieves state-of-the-art performance across classification, retrieval and code understanding tasks while being 2-4x faster than previous encoder models. This makes it ideal for high-throughput production applications like LLM routing, where both accuracy and latency are critical.\n",
    "\n",
    "ModernBERT was trained on 2 trillion tokens of diverse data including web documents, code, and scientific articles - making it much more robust than traditional BERT models trained primarily on Wikipedia. This broader knowledge helps it better understand the nuances of user prompts across different domains.\n",
    "\n",
    "If you want to learn more about ModernBERT's architecture and training process, check out the official [blog](https://huggingface.co/blog/modernbert). \n",
    "\n",
    "---\n",
    "\n",
    "Now let's get started building our LLM router with ModernBERT! 🚀\n",
    "\n",
    "*Note: This tutorial was created and tested on an NVIDIA L4 GPU with 24GB of VRAM.*\n",
    "\n",
    "## Setup environment and install libraries\n",
    "\n",
    "Our first step is to install Hugging Face Libraries and Pyroch, including transformers and datasets. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Pytorch & other libraries\n",
    "%pip install \"torch==2.4.1\" tensorboard \n",
    "%pip install flash-attn \"setuptools<71.0.0\" scikit-learn \n",
    "\n",
    "# Install Hugging Face libraries\n",
    "%pip install  --upgrade \\\n",
    "  \"datasets==3.1.0\" \\\n",
    "  \"accelerate==1.2.1\" \\\n",
    "  \"hf-transfer==0.1.8\"\n",
    "  #\"transformers==4.47.1\" \\\n",
    "\n",
    "# ModernBERT is not yet available in an official release, so we need to install it from github\n",
    "%pip install \"git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1\" --upgrade\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You must register on the [Hugging Face](https://huggingface.co/join) for this. After you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on the disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "login(token=\"\", add_to_git_credential=True) # ADD YOUR TOKEN HERE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load and prepare the dataset\n",
    "\n",
    "In our example we want to fine-tune ModernBERT to act as a router for user prompts. Therefore we need a classification dataset consisting of user prompts and their \"difficulty\" score. We are going to use the `DevQuasar/llm_router_dataset-synth` dataset, which is a synthetic dataset of ~15,000 user prompts with a difficulty score of \"large_llm\" (`1`) or \"small_llm\" (`0`). \n",
    "\n",
    "\n",
    "We will use the `load_dataset()` method from the [🤗 Datasets](https://huggingface.co/docs/datasets/index) library to load the `DevQuasar/llm_router_dataset-synth` dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Dataset id from huggingface.co/dataset\n",
    "dataset_id = \"DevQuasar/llm_router_dataset-synth\"\n",
    "\n",
    "# Load raw dataset\n",
    "raw_dataset = load_dataset(dataset_id)\n",
    "\n",
    "print(f\"Train dataset size: {len(raw_dataset['train'])}\")\n",
    "print(f\"Test dataset size: {len(raw_dataset['test'])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let’s check out an example of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import randrange\n",
    "\n",
    "random_id = randrange(len(raw_dataset['train']))\n",
    "raw_dataset['train'][random_id]\n",
    "# {'id': '6225a9cd-5cba-4840-8e21-1f9cf2ded7e6',\n",
    "# 'prompt': 'How many legs does a spider have?',\n",
    "# 'label': 0}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train our model, we need to convert our text prompts to token IDs. This is done by a Tokenizer, which tokenizes the inputs (including converting the tokens to their corresponding IDs in the pre-trained vocabulary) if you want to learn more about this, out **[chapter 6](https://huggingface.co/course/chapter6/1?fw=pt)** of the [Hugging Face Course](https://huggingface.co/course/chapter1/1)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Model id to load the tokenizer\n",
    "model_id = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "# Load Tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "tokenizer.model_max_length = 512 # set model_max_length to 512 as prompts are not longer than 1024 tokens\n",
    "\n",
    "# Tokenize helper function\n",
    "def tokenize(batch):\n",
    "    return tokenizer(batch['text'], padding='max_length', truncation=True, return_tensors=\"pt\")\n",
    "\n",
    "# Tokenize dataset\n",
    "raw_dataset =  raw_dataset.rename_column(\"label\", \"labels\") # to match Trainer\n",
    "tokenized_dataset = raw_dataset.map(tokenize, batched=True,remove_columns=[\"text\"])\n",
    "\n",
    "print(tokenized_dataset[\"train\"].features.keys())\n",
    "# dict_keys(['input_ids', 'token_type_ids', 'attention_mask','lable'])"
   ]
  },
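  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `padding='max_length'` and `truncation=True` more concrete, here is a minimal, self-contained sketch that uses a toy whitespace tokenizer. The vocabulary, the `toy_encode` helper, and the `max_length` of 6 are assumptions made purely for illustration. The real ModernBERT tokenizer uses a learned subword vocabulary and its own special tokens.\n",
    "\n",
    "```python\n",
    "# Toy illustration only: NOT the real ModernBERT tokenizer\n",
    "PAD_ID, UNK_ID = 0, 1\n",
    "toy_vocab = {'how': 2, 'many': 3, 'legs': 4, 'does': 5, 'a': 6, 'spider': 7, 'have': 8}\n",
    "\n",
    "def toy_encode(text, max_length=6):\n",
    "    # Map lowercase whitespace-separated words to ids, unknown words to UNK_ID\n",
    "    ids = [toy_vocab.get(word, UNK_ID) for word in text.lower().split()]\n",
    "    ids = ids[:max_length]  # truncation: drop everything past max_length\n",
    "    mask = [1] * len(ids)   # 1 marks real tokens\n",
    "    pad = max_length - len(ids)\n",
    "    # padding: fill up to max_length, padded positions get mask 0\n",
    "    return {'input_ids': ids + [PAD_ID] * pad, 'attention_mask': mask + [0] * pad}\n",
    "\n",
    "short_prompt = toy_encode('How many legs')\n",
    "long_prompt = toy_encode('How many legs does a spider have')\n",
    "print(short_prompt)\n",
    "print(long_prompt)\n",
    "```\n",
    "\n",
    "The short prompt is padded with zeros and masked out, while the long prompt loses its last word. This is the same trade-off `model_max_length` controls for the real tokenizer."
   ]
  },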
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Fine-tune & evaluate ModernBERT with the Hugging Face `Trainer`\n",
    "\n",
    "After we have processed our dataset, we can start training our model. We will use the [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) model. The first step is to load our model with `AutoModelForSequenceClassification` class from the [Hugging Face Hub](https://huggingface.co/answerdotai/ModernBERT-base). This will initialize the pre-trained ModernBERT weights with a classification head on top. Here we pass the number of classes (2) from our dataset and the label names to have readable outputs for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# Model id to load the tokenizer\n",
    "model_id = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "# Prepare model labels - useful for inference\n",
    "labels = tokenized_dataset[\"train\"].features[\"labels\"].names\n",
    "num_labels = len(labels)\n",
    "label2id, id2label = dict(), dict()\n",
    "for i, label in enumerate(labels):\n",
    "    label2id[label] = str(i)\n",
    "    id2label[str(i)] = label\n",
    "\n",
    "# Download the model from huggingface.co/models\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_id, num_labels=num_labels, label2id=label2id, id2label=id2label,\n",
    ")"
   ]
  },
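  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two mappings are simply inverse lookups between class ids and readable label names. As a minimal sketch, the snippet below builds them with integer ids, which is how these mappings are typically stored in a model config. It assumes the label names `small_llm` (`0`) and `large_llm` (`1`) from the dataset description and hard-codes them so it runs on its own.\n",
    "\n",
    "```python\n",
    "# Assumed label order: index 0 -> small_llm, index 1 -> large_llm\n",
    "label_names = ['small_llm', 'large_llm']\n",
    "\n",
    "# id -> name is used to turn predicted class ids into readable labels\n",
    "id2label = {i: name for i, name in enumerate(label_names)}\n",
    "# name -> id is the inverse lookup\n",
    "label2id = {name: i for i, name in id2label.items()}\n",
    "\n",
    "print(id2label)\n",
    "print(label2id)\n",
    "```"
   ]
  },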
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We evaluate our model during training. The `Trainer` supports evaluation during training by providing a `compute_metrics` method. We use the `evaluate` library to calculate the [f1 metric](https://huggingface.co/spaces/evaluate-metric/f1) during training on our test split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Metric helper method\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    score = f1_score(\n",
    "            labels, predictions, labels=labels, pos_label=1, average=\"weighted\"\n",
    "        )\n",
    "    return {\"f1\": float(score) if score == 1 else score}"
   ]
  },
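  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the weighted f1 score more concrete, here is a small worked example in plain Python. It is a sketch for illustration, not the code used in training: the `weighted_f1` helper and the toy logits are made up, and the helper simply averages the per-class f1 scores weighted by how often each class appears in the true labels.\n",
    "\n",
    "```python\n",
    "def weighted_f1(y_true, y_pred):\n",
    "    # Per-class F1, averaged with weights equal to the class support\n",
    "    total = 0.0\n",
    "    for cls in sorted(set(y_true)):\n",
    "        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)\n",
    "        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)\n",
    "        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)\n",
    "        denom = 2 * tp + fp + fn\n",
    "        f1 = 2 * tp / denom if denom else 0.0\n",
    "        total += f1 * (tp + fn)  # tp + fn is the support of this class\n",
    "    return total / len(y_true)\n",
    "\n",
    "# Toy logits for 4 prompts and 2 classes (made-up numbers)\n",
    "logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 0.2], [-0.5, 2.5]]\n",
    "y_true = [0, 1, 1, 1]\n",
    "y_pred = [row.index(max(row)) for row in logits]  # argmax per row\n",
    "score = weighted_f1(y_true, y_pred)\n",
    "print(y_pred, round(score, 4))\n",
    "```\n",
    "\n",
    "Here class `0` gets an f1 of 2/3 with support 1 and class `1` gets 0.8 with support 3, so the weighted score is (2/3 + 3 * 0.8) / 4, roughly 0.767."
   ]
  },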
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last step is to define the hyperparameters (`TrainingArguments`) we use for our training. Here we are adding optimizations introduced features for fast training times using `torch_compile` option in the `TrainingArguments`.\n",
    "\n",
    "We also leverage the [Hugging Face Hub](https://huggingface.co/models) integration of the `Trainer` to push our checkpoints, logs, and metrics during training into a repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfFolder\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "# Define training args\n",
    "training_args = TrainingArguments(\n",
    "    output_dir= \"modernbert-llm-router\",\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=16,\n",
    "    learning_rate=5e-5,\n",
    "\t\tnum_train_epochs=5,\n",
    "    bf16=True, # bfloat16 training \n",
    "    optim=\"adamw_torch_fused\", # improved optimizer \n",
    "    # logging & evaluation strategies\n",
    "    logging_strategy=\"steps\",\n",
    "    logging_steps=100,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"f1\",\n",
    "    # push to hub parameters\n",
    "    report_to=\"tensorboard\",\n",
    "    push_to_hub=True,\n",
    "    hub_strategy=\"every_save\",\n",
    "    hub_token=HfFolder.get_token(),\n",
    "\n",
    ")\n",
    "\n",
    "# Create a Trainer instance\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can start our training by using the **`train`** method of the `Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training\n",
    "trainer.train()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tuning `answerdotai/ModernBERT-base` on ~15,000 synthetic prompts for 5 epochs took `321` seconds and our best model achieved a `f1` score of `0.993`. 🚀 I also ran the training with `bert-base-uncased` to compare the training time and performance. The original BERT achieved a `f1` score of `0.99` and took `1048` seconds to train. \n",
    "\n",
    "*Note: ModernBERT and BERT both almost achieve the same performance. This indicates that the dataset is not challenging and probably could be solved using a logistic regression classifier. I ran the same code on the [banking77](https://huggingface.co/datasets/legacy-datasets/banking77) dataset. A dataset of ~13,000 customer service queries with 77 classes. There the ModernBERT outperformed the original BERT by 3% (f1 score of 0.93 vs 0.90)*\n",
    "\n",
    "\n",
    "Lets save our final best model and tokenizer to the Hugging Face Hub and create a model card."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save processor and create model card\n",
    "tokenizer.save_pretrained(\"modernbert-llm-router\")\n",
    "trainer.create_model_card()\n",
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run Inference & test model\n",
    "\n",
    "To wrap up this tutorial, we will run inference on a few examples and test our model. We will use the `pipeline` method from the `transformers` library to run inference on our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# load model from huggingface.co/models using our repository id\n",
    "classifier = pipeline(\"sentiment-analysis\", model=\"modernbert-llm-router\", device=0)\n",
    "\n",
    "sample = \"How does the structure and function of plasmodesmata affect cell-to-cell communication and signaling in plant tissues, particularly in response to environmental stresses?\"\n",
    "\n",
    "\n",
    "pred = classifier(sample)\n",
    "print(pred)\n",
    "# [{'label': 'large_llm', 'score': 1.0}]"
   ]
  },
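  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A prediction like the one above can then drive the actual routing decision. The following is a minimal sketch under stated assumptions: the model names in `MODEL_BY_LABEL`, the `route` helper, and the confidence threshold are hypothetical and not part of the training code above. The prediction is hard-coded so the snippet runs without the fine-tuned model.\n",
    "\n",
    "```python\n",
    "# Hypothetical mapping from predicted label to a downstream model name\n",
    "MODEL_BY_LABEL = {'small_llm': 'my-small-model', 'large_llm': 'my-large-model'}\n",
    "\n",
    "def route(prediction, threshold=0.5):\n",
    "    # prediction is one item of the pipeline output: {'label': ..., 'score': ...}\n",
    "    # If the classifier is not confident, fall back to the large model to be safe\n",
    "    if prediction['score'] < threshold:\n",
    "        return MODEL_BY_LABEL['large_llm']\n",
    "    return MODEL_BY_LABEL[prediction['label']]\n",
    "\n",
    "example_pred = [{'label': 'large_llm', 'score': 1.0}]  # same shape as the output above\n",
    "print(route(example_pred[0]))\n",
    "```\n",
    "\n",
    "Falling back to the larger model on low-confidence predictions is one possible design choice: it trades some cost for a lower risk of sending a hard prompt to a model that cannot handle it."
   ]
  },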
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this tutorial, we learned how to fine-tune ModernBERT for an LLM routing classification task. We demonstrated how to leverage the Hugging Face ecosystem to efficiently train and deploy a specialized classifier that can intelligently route user prompts to the most appropriate LLM model.\n",
    "\n",
    "Using modern training optimizations like flash attention, fused optimizers and mixed precision, we were able to train our model efficiently. Comparing ModernBERT with the original BERT we reduced training time by approximately 3x (1048s vs 321s) on our dataset and outperformed the original BERT by 3% on a more challenging dataset. But more importantly, ModernBERT was trained on 2 trillion tokens, which are more diverse and up to date than the Wikipedia-based training data of the original BERT.\n",
    "\n",
    "This example showcases how smaller, specialized models remain valuable in the age of large language models - particularly for high-throughput, latency-sensitive tasks like LLM routing. By using ModernBERT's improved architecture and broader training data, we can build more robust and efficient classification systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
